{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training an Iris classifier using XGBoost\n",
    "\n",
    "In this notebook, we build a simple classifier on the Iris dataset using XGBoost, export it to PMML with nyoka, and validate the exported model with pypmml."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install scikit-learn==0.20.*\n",
    "!pip install xgboost==0.90\n",
    "!pip install pandas\n",
    "!pip install nyoka\n",
    "!pip install pypmml"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data\n",
    "The Iris dataset ships with scikit-learn. We load it into a pandas DataFrame and hold out a third of the rows as a test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import datasets\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd\n",
    "\n",
    "iris = datasets.load_iris()\n",
    "target = 'species'\n",
    "features = iris.feature_names\n",
    "iris_df = pd.DataFrame(iris.data, columns=features)\n",
    "iris_df[target] = iris.target\n",
    "\n",
    "X, y = iris_df[features], iris_df[target]\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=123456)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model\n",
    "Build a pipeline that standardizes the features and fits XGBoost's `XGBClassifier`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from xgboost import XGBClassifier\n",
    "\n",
    "pipeline = Pipeline([\n",
    "    ('scaling', StandardScaler()),\n",
    "    ('xgb', XGBClassifier(n_estimators=5, random_state=123456))\n",
    "])\n",
    "\n",
    "pipeline.fit(X_train, y_train)\n",
    "print(\"Test data accuracy of the xgb classifier is {:.2f}\".format(pipeline.score(X_test, y_test)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert the model to PMML\n",
    "Export the fitted pipeline to a PMML file using nyoka:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nyoka import xgboost_to_pmml\n",
    "\n",
    "xgboost_to_pmml(pipeline, features, target, \"xgb-iris.pmml\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validate the PMML\n",
    "Load the exported PMML file with pypmml and check that its predictions match those of the Python model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pypmml import Model\n",
    "import numpy as np\n",
    "\n",
    "model = Model.fromFile(\"xgb-iris.pmml\")\n",
    "result = model.predict(X_test)\n",
    "result.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make predictions using the Python model\n",
    "predictions = pipeline.predict(X_test)\n",
    "probabilities = pipeline.predict_proba(X_test)\n",
    "\n",
    "# Compare results: labels must match, probabilities must agree to 3 decimal places\n",
    "np.testing.assert_almost_equal(result['predicted_species'], predictions)\n",
    "np.testing.assert_array_almost_equal(result[['probability_0', 'probability_1', 'probability_2']], probabilities, decimal=3)\n",
    "print(\"The PMML predictions match the Python model.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
